# Copyright 2020 Makani Technologies LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""C code generation backend."""

import re
import textwrap

from makani.lib.python import string_util
from makani.lib.python.pack2 import backend


class BackendC(backend.Backend):
  """C code generation backend.

  Builds two strings as pack2 types are added:
    * `header_string`: a C header with enum/struct typedefs and param
      type-version helpers, wrapped in an include guard.
    * `source_string`: a C source file holding param flash-pointer
      definitions.
  Call Finalize() after the last type has been added, then retrieve the
  generated text with GetSourceString().
  """

  # Maps pack2 primitive type names to their C equivalents.  Type names not
  # listed here (e.g. other generated structs/enums) are emitted unchanged.
  _primary_type_map = {
      'uint8': 'uint8_t',
      'int8': 'int8_t',
      'uint16': 'uint16_t',
      'int16': 'int16_t',
      'uint32': 'uint32_t',
      'int32': 'int32_t',
      'float32': 'float',
      'date': 'uint32_t',
  }

  # Maps a param kind (parsed from the param name in AddParam) to the
  # linker-script symbol address that the generated flash pointer is set to.
  _param_section_map = {
      'Config': '&ldscript_config_param_data',
      'Calib': '&ldscript_calib_param_data',
      'Serial': '&ldscript_serial_param_data',
  }

  # Base (unversioned) param struct names.  AddParam must not emit a
  # flash-pointer definition for these, so the base symbols are not
  # redefined by generated code.
  _base_param_names = frozenset(
      ['ConfigParams', 'CalibParams', 'SerialParams'])

  def __init__(self, header_path):
    """Initializes the backend and emits the file preambles.

    Args:
      header_path: Path of the generated header.  It is #included by the
          generated source file and used to derive the include guard.
    """
    # Name the class explicitly: super(self.__class__, self) would recurse
    # forever if this class were ever subclassed.
    super(BackendC, self).__init__()

    self.header_path = header_path

    self._StartSource()
    self._StartHeader()

  def _StartSource(self):
    """Initializes source_string with the generated-file preamble."""
    self.source_string = textwrap.dedent("""\
        // This file is automatically generated.  Do not edit.
        #include "{header_path}"

        #include <stdint.h>

        #ifdef PACK2_FLASH_POINTERS
        #include "avionics/firmware/startup/ldscript.h"
        #endif
        """).format(header_path=self.header_path)

  def _FinalizeSource(self):
    """Completes source_string; the source file needs no trailer."""
    pass

  def _StartHeader(self):
    """Initializes header_string with the preamble and include guard."""
    self.header_string = textwrap.dedent("""\
        // This file is automatically generated.  Do not edit.
        #ifndef {guard}
        #define {guard}

        #include <stdint.h>

        """).format(guard=self._HeaderGuard())

  def _FinalizeHeader(self):
    """Closes the include guard opened by _StartHeader."""
    self.header_string += textwrap.dedent("""\
        #endif  // {guard}
        """).format(guard=self._HeaderGuard())

  def _HeaderGuard(self):
    """Returns the include-guard macro name derived from header_path.

    Every character that is not valid in a C identifier is replaced with
    '_', the result is upper-cased and a trailing '_' is appended, e.g.
    'avionics/foo.h' -> 'AVIONICS_FOO_H_'.
    """
    return re.sub(r'[^A-Za-z0-9_]', '_', self.header_path).upper() + '_'

  def AddInclude(self, path):
    """Appends `#include "<path>.h"` to the header."""
    self.header_string += '#include "%s.h"\n' % path

  def AddBitfield(self, bitfield):
    """Bitfields are not supported by the C backend.

    Raises:
      NotImplementedError: Always.
    """
    raise NotImplementedError('Bitfields not implemented for %s'
                              % self.__class__.__name__)

  def AddEnum(self, enum):
    """Appends a packed C enum typedef for `enum` to the header.

    Each value becomes an enumerator named k<EnumName><ValueName>.  Helper
    enumerators are added as needed:
      k<EnumName>ForceSigned = -1: when no value is negative, so the C enum
          always has a signed underlying type.
      kNum<PluralEnumName>: when the values are exactly 0..N-1.
      k<EnumName>ForceSize: the largest signed value representable in
          `enum.width` bytes, when it is not already a value, so the packed
          enum spans the full declared width.

    Args:
      enum: pack2 enum exposing `name`, `width` (size in bytes) and
          `body.value_map` (integer value -> value name).
    """
    name = enum.name
    value_map = enum.body.value_map
    values = sorted(value_map.keys())
    lines = ['typedef enum {']

    if not any(v < 0 for v in values):
      lines.append('  k{name}ForceSigned = -1,'.format(name=name))

    for value in values:
      lines.append('  k{name}{value_name} = {value},'.format(
          name=name, value_name=value_map[value], value=value))

    # Largest positive value of a signed integer that is `width` bytes wide.
    max_value = (1 << (enum.width * 8 - 1)) - 1

    # A count enumerator only makes sense for contiguous values starting at
    # zero.  The emptiness check avoids an IndexError on an empty enum.
    if values and values[0] == 0 and values[-1] + 1 == len(values):
      lines.append('  kNum{name} = {num},'.format(
          name=string_util.GetPlural(name), num=len(values)))

    if max_value not in values:
      lines.append('  k{name}ForceSize = 0x{val:x},'.format(
          name=name, val=max_value))

    lines.append('}} __attribute__((packed)) {name};'.format(name=name))
    self.header_string += '\n'.join(lines) + '\n\n'

  def AddStruct(self, struct):
    """Appends a C struct typedef for `struct` to the header.

    Fields of type 'string' become char arrays sized by the type's width.
    Other fields become a scalar when `extent == 1`, or otherwise an array
    of `extent` elements.  Primitive type names are translated through
    _primary_type_map.

    Args:
      struct: pack2 struct exposing `name` and `body.fields`.
    """
    lines = ['typedef struct {']
    for field in struct.body.fields:
      pack_type = field.type_obj.name
      type_name = self._primary_type_map.get(pack_type, pack_type)

      if type_name == 'string':
        lines.append('  char {name}[{size}];'.format(
            name=field.name, size=field.type_obj.width))
      elif field.extent == 1:
        lines.append('  {type_name} {name};'.format(
            type_name=type_name, name=field.name))
      else:
        lines.append('  {type_name} {name}[{extent}];'.format(
            type_name=type_name, name=field.name, extent=field.extent))
    lines.append('}} {type_name};'.format(type_name=struct.name))
    self.header_string += '\n'.join(lines) + '\n\n'

  def AddScaled(self, bitfield):
    """Scaled types are not supported by the C backend.

    Raises:
      NotImplementedError: Always.
    """
    raise NotImplementedError('Scaleds not implemented for %s'
                              % self.__class__.__name__)

  def AddHeader(self, header):
    """Emits a pack2 header type as a plain C struct typedef."""
    self.AddStruct(header)

  def AddParam(self, param):
    """Emits a param struct plus its type-version helpers.

    The header receives the struct typedef, an extern flash pointer
    declaration (guarded by PACK2_FLASH_POINTERS), a k<Name>Crc constant
    holding param.Crc32(), and an inline <Name>GetTypeVersion() returning
    it.  For non-base params, the source receives the flash pointer
    definition pointing at the matching linker-script section.

    Args:
      param: pack2 param whose name ends in ConfigParams, CalibParams or
          SerialParams, optionally followed by a V<digits> version suffix.

    Raises:
      ValueError: If the param kind cannot be determined from its name.
    """
    match = re.search(r'(Config|Calib|Serial)Params(V[0-9]+)?$', param.name)
    if not match:
      raise ValueError("Can't determine params type from name %s" % param.name)
    param_type = match.group(1)

    self.AddStruct(param)
    self.header_string += textwrap.dedent("""\
        #ifdef PACK2_FLASH_POINTERS
        extern const {type_name} *k{type_name};
        #endif
        static const uint32_t k{type_name}Crc = 0x{crc:08x};
        static inline uint32_t {type_name}GetTypeVersion(void) {{
            return k{type_name}Crc;
        }}

        """).format(type_name=param.name,
                    crc=param.Crc32())

    # Make sure we don't redefine the base param symbols.
    if param.name not in self._base_param_names:
      self.source_string += textwrap.dedent("""\
          #ifdef PACK2_FLASH_POINTERS
          const {type_name} *k{type_name} = {section};
          #endif
          """).format(type_name=param.name,
                      section=self._param_section_map[param_type])

  def Finalize(self):
    """Appends trailers to the source and header; call once at the end."""
    self._FinalizeSource()
    self._FinalizeHeader()

  def GetSourceString(self, name):
    """Returns the generated text for `name`.

    Args:
      name: 'header' for the C header or 'source' for the C source file.

    Raises:
      ValueError: If `name` is neither 'header' nor 'source'.
    """
    if name == 'header':
      return self.header_string
    elif name == 'source':
      return self.source_string
    else:
      raise ValueError('Unknown source %s.' % name)
